#! /home/jyp/.miniconda3/envs/lvio_fusion/bin/python
# -*- coding: utf-8 -*-

from os import stat
import rospy
import sys
import select
import tty
import termios
from lvio_fusion_node.srv import *
from rl_fusion.td3 import *
from rl_fusion.env import LvioFusionEnv

# Node-wide run state shared by the handlers below:
#   0: have not started (initial value)
#   1: start training   (set by init_callback before it calls train_td3)
#   2: start evaluation (set by load before it calls load_td3; the only
#      state in which update_callback answers a request)
state = 0


def load():
    """Enter evaluation state (2) and call load_td3(), once.

    Does nothing unless the node is still in the initial state (0), so it
    cannot run after training or a previous load has already started.
    """
    global state
    if state != 0:
        return

    # Flip the state first so update_callback/init_callback see the new mode
    # while load_td3() is still running.
    state = 2
    rospy.logwarn("Loading test_td3.")
    load_td3()
    rospy.logwarn("Loaded test_td3.")


def init_callback(req):
    """Service handler for '/lvio_fusion_node/init': start training once.

    Args:
        req: the Init service request (not inspected).

    Returns:
        None in every case.
        NOTE(review): rospy reports a None handler result to the caller as a
        failed call -- confirm whether the client relies on that or whether
        an Init response should be returned.
    """
    global state
    if state != 0:
        # Training or evaluation already started; ignore the request.
        return

    state = 1
    rospy.logwarn("Start test_td3.")
    # Runs inside the service handler, so the call does not return to the
    # client until train_td3() itself returns.
    train_td3()


def update_callback(req):
    """Service handler for '/lvio_fusion_node/update_weights'.

    Args:
        req: request whose ``obs`` field is passed to get_weights().

    Returns:
        UpdateWeightsResponse built from the first three weights when the
        node is in evaluation state (2); otherwise None (no response).
    """
    if state != 2:
        # Only answer once load() has switched the node to evaluation.
        return None

    weights = get_weights(req.obs)
    first, second, third = weights[0], weights[1], weights[2]
    rospy.logwarn("Weights: %f, %f, %f" % (first, second, third))
    return UpdateWeightsResponse(first, second, third)


def get_key():
    """Block until one character is available on stdin and return it.

    stdin is switched to raw mode for the duration of the read so a single
    key press is delivered without waiting for Enter. The terminal
    attributes saved in the module-level ``settings`` (captured by the main
    script before the key loop) are always restored afterwards, even if the
    wait or the read raises.

    Returns:
        The single character read, or '' if select() reported nothing
        readable.
    """
    tty.setraw(sys.stdin.fileno())
    try:
        # timeout=None blocks until stdin is readable, so the '' fallback is
        # purely defensive.
        rlist, _, _ = select.select([sys.stdin], [], [], None)
        key = sys.stdin.read(1) if rlist else ''
    finally:
        # Never leave the user's terminal stuck in raw mode.
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings)
    return key


if __name__ == '__main__':
    # Entry point: configure the environment class from private node
    # parameters, expose the init/update_weights services, create the
    # clients used by LvioFusionEnv, then wait for the user to quit with 'q'.
    try:
        rospy.init_node('rl_fusion_node')
        LvioFusionEnv.obs_rows = rospy.get_param('~obs_rows')
        LvioFusionEnv.obs_cols = rospy.get_param('~obs_cols')
        mode = rospy.get_param('~mode')
        server_init = rospy.Service(
            '/lvio_fusion_node/init', Init, init_callback)
        server_update_weights = rospy.Service(
            '/lvio_fusion_node/update_weights', UpdateWeights, update_callback)
        LvioFusionEnv.client_create_env = rospy.ServiceProxy(
            '/lvio_fusion_node/create_env', CreateEnv)
        LvioFusionEnv.client_step = rospy.ServiceProxy(
            '/lvio_fusion_node/step', Step)

        # mode 2 means evaluation: load before serving update_weights.
        if mode == 2:
            load()

        # Saved terminal attributes; get_key() reads this module-level name
        # to undo raw mode after every key press.
        settings = termios.tcgetattr(sys.stdin)

        # Must be initialised before the loop: it is read on the first 'q'.
        # NOTE(review): this flag is only consulted here. If it was meant to
        # tell the training code to stop, rebinding this script's name does
        # not reach another module -- confirm the intent.
        stop = False
        try:
            rate = rospy.Rate(100)
            while not rospy.is_shutdown():
                key = get_key()
                if key == 'q':
                    # First 'q' arms the flag, the second one exits.
                    if stop:
                        break
                    stop = True
                rate.sleep()
        finally:
            # Restore the terminal on every exit path, including ROS
            # shutdown and unexpected errors.
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings)
    except rospy.ROSInterruptException:
        pass
